2023.6.27 NAG最適化について
NAGそのものの答えでは無いのですが、次のように損失関数をひとつにまとめると、見通しが良くなると思いました。
code:python
import torch
def func(arg):
W, H = arg
criterion = (2*W@H).sum()
return criterion
W = torch.rand(3,2).requires_grad_(True)
H = torch.rand(2,4).requires_grad_(True)
criterion = func(W, H) # まとめて渡して損失関数を1つにまとまる criterion.backward()
print('W.grad :\n',W.grad)
print('H.grad :\n', H.grad)
optimizerはW, Hそれぞれ独立して用意する必要がありそう。
また、$ \nabla_* {\cal L}(m+n)については、似たような考え方はルンゲ-クッタ法にもあり、
code:python
def func(t, y):
....
として定義した関数に対して変更を加えたt, yを与えて処理を行っている。